{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../docs/images/DSPy8.png\" alt=\"DSPy7 Image\" height=\"150\"/>\n",
    "\n",
    "## **DSPy Assertions**: Asserting Computational Constraints on Foundation Models\n",
    "\n",
    "### **TweetGen**: Generating tweets to answer questions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[<img align=\"center\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" />](https://colab.research.google.com/github/stanfordnlp/dspy/blob/main/examples/tweets/tweets_assertions.ipynb)\n",
    "\n",
    "\n",
    "This notebook highlights an example of [**DSPy Assertions**](https://dspy-docs.vercel.app/docs/building-blocks/assertions), allowing for declaration of computational constraints within DSPy programs. \n",
    "\n",
    "\n",
    "This notebook builds upon the foundational concepts of the **DSPy** framework. Prerequisites of following this notebook is having gone through the [DSPy tutorial](../../intro.ipynb), the [**DSPy Assertions documentation**](https://dspy-docs.vercel.app/docs/building-blocks/assertions) and the introductory DSPy Assertions [tutorial on LongFormQA](../longformqa/longformqa_assertions.ipynb).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://huggingface.co/arnavs11/DSPy_TweetGen_Cache\n",
    "%cd DSPy_TweetGen_Cache/\n",
    "!git checkout master\n",
    "%cd ..\n",
    "import os\n",
    "repo_clone_path = '/content/DSPy_TweetGen_Cache'\n",
    "\n",
    "# Check if '/content' is writable\n",
    "if not os.access('/content', os.W_OK):\n",
    "    # If '/content' is not writable, choose an alternative directory\n",
    "    # Example: using a directory relative to the current working directory\n",
    "    repo_clone_path = os.path.join(os.getcwd(), 'DSPy_TweetGen_Cache')\n",
    "\n",
    "# Set up the cache for this notebook\n",
    "os.environ[\"DSP_NOTEBOOK_CACHEDIR\"] = repo_clone_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import os\n",
    "import regex as re\n",
    "import json\n",
    "\n",
    "try: # When on google Colab, let's clone the notebook so we download the cache.\n",
    "    import google.colab\n",
    "    repo_path = 'dspy'\n",
    "    \n",
    "    !git -C $repo_path pull origin || git clone https://github.com/stanfordnlp/dspy $repo_path\n",
    "except:\n",
    "    repo_path = '.'\n",
    "\n",
    "if repo_path not in sys.path:\n",
    "    sys.path.append(repo_path)\n",
    "\n",
    "\n",
    "import pkg_resources # Install the package if it's not installed\n",
    "if not \"dspy-ai\" in {pkg.key for pkg in pkg_resources.working_set}:\n",
    "    !pip install -U pip\n",
    "    !pip install dspy-ai\n",
    "    !pip install openai~=0.28.1\n",
    "    !pip install -e $repo_path\n",
    "\n",
    "import dspy\n",
    "from dspy.predict import Retry\n",
    "from dspy.datasets import HotPotQA\n",
    "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n",
    "from dsp.utils import deduplicate\n",
    "from dspy.evaluate.evaluate import Evaluate\n",
    "from dspy.primitives.assertions import assert_transform_module, backtrack_handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')\n",
    "dspy.settings.configure(rm=colbertv2_wiki17_abstracts)\n",
    "turbo = dspy.OpenAI(model='gpt-3.5-turbo-0613', max_tokens=500)\n",
    "dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = HotPotQA(train_seed=1, train_size=300, eval_seed=2023, dev_size=300, test_size=0, keep_details=True)\n",
    "trainset = [x.with_inputs('question', 'answer') for x in dataset.train]\n",
    "devset = [x.with_inputs('question', 'answer') for x in dataset.dev]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3] TweetGen\n",
    "\n",
    "Let's introduce a new task: TweetGen. We extend the `Multi-Hop QA` program, but now aim to present the answer generation in the form of a tweet. \n",
    "\n",
    "The `Tweeter` module captures the iterative multi-hop generation process from `Multi-Hop QA` in query generation, passage retrieval, and context assembly. The `GenerateTweet` layer now utilizes the context alongside the question to generate a tweet that effectively answers the question."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this program, we aim to generate tweets that adhere to the following guidelines:\n",
    "1. The tweet has no hashtags. \n",
    "2. The tweet includes the correct answer\n",
    "3. The tweet is within a character limit. \n",
    "4. The tweet is engaging\n",
    "5. The tweet is faithful"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GenerateSearchQuery(dspy.Signature):\n",
    "    \"\"\"Write a simple search query that will help answer a complex question.\"\"\"\n",
    "    context = dspy.InputField(desc=\"may contain relevant facts\")\n",
    "    question = dspy.InputField()\n",
    "    query = dspy.OutputField()\n",
    "\n",
    "class GenerateTweet(dspy.Signature):\n",
    "    \"\"\"Generate an engaging tweet that effectively answers a question staying faithful to the context, is less than 280 characters, and has no hashtags.\"\"\"\n",
    "    question = dspy.InputField()\n",
    "    context = dspy.InputField(desc=\"may contain relevant facts\")\n",
    "    tweet = dspy.OutputField()\n",
    "\n",
    "class Tweeter(dspy.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.generate_tweet = dspy.ChainOfThought(GenerateTweet)\n",
    "\n",
    "    def forward(self, question, answer):\n",
    "        context = []\n",
    "        max_hops=2\n",
    "        passages_per_hop=3\n",
    "        generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]\n",
    "        retrieve = dspy.Retrieve(k=passages_per_hop)\n",
    "        for hop in range(max_hops):\n",
    "            query = generate_query[hop](context=context, question=question).query\n",
    "            passages = retrieve(query).passages\n",
    "            context = deduplicate(context + passages)\n",
    "        generated_tweet = self.generate_tweet(question=question, context=context).tweet\n",
    "        return dspy.Prediction(generated_tweet=generated_tweet, context=context)\n",
    "    \n",
    "tweeter = Tweeter()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4] Evaluation - Intrinsic and Extrinsic\n",
    "\n",
    "#### Intrinsic Metrics: passing internal computational constraints is the goal \n",
    "\n",
    "**No Hashtags** - This is a user-personalized constraint to test how well the model can follow a specific, yet simple guideline of not including any hashtags within the generated tweet.\n",
    "\n",
    "**Correct Answer Inclusion** - This is a general check to ensure the tweet indeed has the correct answer to the question.\n",
    "\n",
    "**Within Length** - This check follows Twitter platform guidelines of 280 character limits per tweet.\n",
    "\n",
    "**Engagement** - To verify the engagement quality of the tweet, we define and call another **DSPy** program: ``Predict`` on ``AssessTweet``, relying on the same LM to answer the question: `\"Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging.\"`\n",
    "\n",
    "**Faithfulness** - To verify the faithfulness of the tweet to its referenced context, we similarly use `AssessTweet` as above but prompt it with the question: `\"Is the assessed text grounded in the context? Say no if it includes significant facts not in the context.\"`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def has_no_hashtags(text):\n",
    "    return len(re.findall(r\"#\\w+\", text)) == 0\n",
    "\n",
    "def is_within_length_limit(text, length_limit=280):\n",
    "    return len(text) <= length_limit\n",
    "\n",
    "def is_assessment_yes(assessment_answer):\n",
    "    \"\"\"Check if the first word of the assessment answer is 'yes'.\"\"\"\n",
    "    return assessment_answer.split()[0].lower() == 'yes'\n",
    "\n",
    "def has_correct_answer(text, answer):\n",
    "    return answer in text\n",
    "\n",
    "\n",
    "class AssessTweet(dspy.Signature):\n",
    "    \"\"\"Assess the quality of a tweet along the specified dimension.\"\"\"\n",
    "\n",
    "    context = dspy.InputField(desc='ignore if N/A')\n",
    "    assessed_text = dspy.InputField()\n",
    "    assessment_question = dspy.InputField()\n",
    "    assessment_answer = dspy.OutputField(desc=\"Yes or No\")\n",
    "\n",
    "def no_hashtags_metric(gold, pred, trace=None):\n",
    "    tweet = pred.generated_tweet\n",
    "    no_hashtags = has_no_hashtags(tweet)\n",
    "    score = no_hashtags\n",
    "    return score\n",
    "\n",
    "def is_correct_metric(gold, pred, trace=None):\n",
    "    answer, tweet = gold.answer, pred.generated_tweet\n",
    "    correct = has_correct_answer(tweet, answer)\n",
    "    score = correct\n",
    "    return score\n",
    "\n",
    "def within_length_metric(gold, pred, trace=None):\n",
    "    tweet = pred.generated_tweet\n",
    "    within_length_limit = is_within_length_limit(tweet, 280)\n",
    "    score = within_length_limit\n",
    "    return score\n",
    "\n",
    "def engaging_metric(gold, pred, trace=None):\n",
    "    tweet = pred.generated_tweet\n",
    "    engaging = \"Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging.\"\n",
    "    engaging = dspy.Predict(AssessTweet)(context='N/A', assessed_text=tweet, assessment_question=engaging)\n",
    "    engaging = engaging.assessment_answer.split()[0].lower() == 'yes'\n",
    "    score = engaging\n",
    "    return score\n",
    "\n",
    "def faithful_metric(gold, pred, trace=None):\n",
    "    context, tweet = pred.context, pred.generated_tweet\n",
    "    faithful = \"Is the assessed text grounded in the context? Say no if it includes significant facts not in the context.\"   \n",
    "    faithful = dspy.Predict(AssessTweet)(context=context, assessed_text=tweet, assessment_question=faithful)\n",
    "    faithful = faithful.assessment_answer.split()[0].lower() == 'yes'\n",
    "    score = faithful\n",
    "    return score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Extrinsic Metrics: Assess the overall quality and effectiveness of generated output on downstream task\n",
    "\n",
    "The extrinsic metric is defined as the overall quality of the generated tweet in following the mentioned constraints, and this is evaluated over a composite metric.\n",
    "\n",
    "While maintaining the most relevant intrinsic metrics of forming a valid tweet in the correctness and within_length constraints, the overall composite metric returns an averaged score over the 5 intrinsic metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def overall_metric(gold, pred, trace=None):\n",
    "    answer, context, tweet = gold.answer, pred.context, pred.generated_tweet\n",
    "    no_hashtags = has_no_hashtags(tweet)\n",
    "    within_length_limit = is_within_length_limit(tweet, 280)\n",
    "    correct = has_correct_answer(tweet, answer)\n",
    "    engaging = \"Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging.\"\n",
    "    faithful = \"Is the assessed text grounded in the context? Say no if it includes significant facts not in the context.\"   \n",
    "    faithful = dspy.Predict(AssessTweet)(context=context, assessed_text=tweet, assessment_question=faithful)\n",
    "    engaging = dspy.Predict(AssessTweet)(context='N/A', assessed_text=tweet, assessment_question=engaging)\n",
    "    engaging, faithful = [m.assessment_answer.split()[0].lower() == 'yes' for m in [engaging, faithful]]\n",
    "    score = (correct + engaging + faithful + no_hashtags + within_length_limit) if correct and within_length_limit else 0\n",
    "    return score / 5.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We hence define the evaluation as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = [no_hashtags_metric, is_correct_metric, within_length_metric, engaging_metric, faithful_metric, overall_metric]\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(tweeter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at an example tweet generation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = devset[118]\n",
    "tweet = tweeter(question=example.question, answer = example.answer)\n",
    "print(f'Generated Tweet: ', tweet.generated_tweet)\n",
    "tweet.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset[118:119], num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(tweeter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we see that the generated tweet is within the length of 280 characters at 151 characters. It does in fact include the correct answer `Hooke`.\n",
    "\n",
    "However, it fails to not include hashtags as we see `#knowledge` at the end of the tweet. Additionally, the tweet has been determined to not be engaging, which makes sense from an eye-test as it simply states the answer and nothing more. \n",
    "\n",
    "Let's try to fix this and produce tweets using DSPy Assertions. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5] Introducing Assertions: TweeterWithAssertions\n",
    "\n",
    "To correct these various errors, let's include assertions that simply reiterate our computational constraints within DSPy Assertion semantics. \n",
    "\n",
    "In the first **Assertion**, we check for if the generated tweet has any hashtags through regex and if violated, assert: **\"Please revise the tweet to remove hashtag phrases following it.\"**\n",
    "\n",
    "Similarly, we check for the tweet length and if it is not within 280 characters, we send the feedback message: **\"Please ensure the tweet is within {280} characters.\"**\n",
    "\n",
    "We check for if the generated tweet has the answer and if not, we assert: **\"The tweet does not include the correct answer to the question. Please revise accordingly.\"**\n",
    "\n",
    "For the engagement and faithfulness checks, we make use of the setup from above, checking for if the respective assessment is determined as `Yes` or `No`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TweeterWithAssertions(dspy.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.generate_tweet = dspy.ChainOfThought(GenerateTweet)\n",
    "\n",
    "    def forward(self, question, answer):\n",
    "        context = []\n",
    "        max_hops=2\n",
    "        passages_per_hop=3\n",
    "        generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]\n",
    "        retrieve = dspy.Retrieve(k=passages_per_hop)\n",
    "        for hop in range(max_hops):\n",
    "            query = generate_query[hop](context=context, question=question).query\n",
    "            passages = retrieve(query).passages\n",
    "            context = deduplicate(context + passages)\n",
    "        generated_tweet = self.generate_tweet(question=question, context=context).tweet\n",
    "        dspy.Suggest(has_no_hashtags(generated_tweet), f\"Please revise the tweet to remove hashtag phrases following it.\", target_module=GenerateTweet)\n",
    "        dspy.Suggest(is_within_length_limit(generated_tweet, 280), f\"Please ensure the tweet is within {280} characters.\", target_module=GenerateTweet)\n",
    "        dspy.Suggest(has_correct_answer(generated_tweet, answer), \"The tweet does not include the correct answer to the question. Please revise accordingly.\", target_module=GenerateTweet)\n",
    "        engaging_question = \"Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging.\"\n",
    "        engaging_assessment = dspy.Predict(AssessTweet)(context=context, assessed_text=generated_tweet, assessment_question=engaging_question)\n",
    "        dspy.Suggest(is_assessment_yes(engaging_assessment.assessment_answer), \"The text is not engaging enough. Please revise to make it more captivating.\", target_module=GenerateTweet)\n",
    "        faithful_question = \"Is the assessed text grounded in the context? Say no if it includes significant facts not in the context.\"\n",
    "        faithful_assessment = dspy.Predict(AssessTweet)(context='N/A', assessed_text=generated_tweet, assessment_question=faithful_question)\n",
    "        dspy.Suggest(is_assessment_yes(faithful_assessment.assessment_answer), \"The text contains unfaithful elements or significant facts not in the context. Please revise for accuracy.\", target_module=GenerateTweet)\n",
    "        return dspy.Prediction(generated_tweet=generated_tweet, context=context)\n",
    "\n",
    "tweeter_with_assertions = assert_transform_module(TweeterWithAssertions().map_named_predictors(Retry), backtrack_handler) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's evaluate the `TweeterWithAssertions` now over the devset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = [no_hashtags_metric, is_correct_metric, within_length_metric, engaging_metric, faithful_metric, overall_metric]\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(tweeter_with_assertions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's take a look at how our generated tweet has improved with the addition of assertions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = devset[118]\n",
    "tweet = tweeter_with_assertions(question=example.question, answer = example.answer)\n",
    "print(f'Generated Tweet: ', tweet.generated_tweet)\n",
    "tweet.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset[118:119], num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(tweeter_with_assertions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the tweet has improved significantly, following all of our set constraints! \n",
    "\n",
    "It no longer has hashtags, and is both engaging and faithful, while maintaining the inclusion of the correct answer within 280 characters. Exciting!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6] Compilation With Assertions\n",
    "\n",
    "We can leverage **DSPy**'s`BootstrapFewShotWithRandomSearch` optimizer, to automatically generate few-shot demonstrations and conduct a random search over the candidates to output the best compiled program. We evaluate this over the `overall_metric` composite metric. \n",
    "\n",
    "We can first evaluate this on `Tweeter` to see how compilation performs without the inclusion of assertions. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
    "compiled_tweeter = teleprompter.compile(student = tweeter, teacher = tweeter, trainset=trainset, valset=devset[:100])\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(compiled_tweeter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we test the compilation on 2 settings with assertions:\n",
    "\n",
    "**Compilation with Assertions**: assertion-driven example bootstrapping and counterexample bootstrapping during compilation. Teacher has assertions while the student does not as the student learns from the teacher's assertion-driven bootstrapped examples. \n",
    "\n",
    "**Compilation + Inference with Assertions**: assertion-driven optimizations for both the teacher and student to offer enhanced assertion-driven outputs during both compilation and inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
    "compiled_with_assertions_tweeter = teleprompter.compile(student=tweeter, teacher = tweeter_with_assertions, trainset=trainset, valset=devset[:100])\n",
    "\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(compiled_with_assertions_tweeter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6, num_threads=1)\n",
    "compiled_tweeter_with_assertions = teleprompter.compile(student=tweeter_with_assertions, teacher = tweeter_with_assertions, trainset=trainset, valset=devset[:100])\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(compiled_tweeter_with_assertions)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dspy_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
